/*
 * File:        computeExtensionFields2d.c
 * Copyright:   (c) 2005-2006 Kevin T. Chu
 * Revision:    $Revision: 1.14 $
 * Modified:    $Date: 2006/09/18 16:17:02 $
 * Description: MATLAB MEX-file for using the fast marching method to
 *              compute extension fields for 2d level set functions
 */

/*===========================================================================
 *
 * computeExtensionFields2d() computes a distance function and extension
 * fields from an arbitrary level set function using the Fast Marching
 * Method.
 * 
 * Usage: [distance_function, extension_fields] = ...
 *        computeExtensionFields2d(phi, source_fields, dX, ...
 *                                 mask, ...
 *                                 spatial_derivative_order)
 *
 * Arguments:
 * - phi:                       level set function to use in 
 *                                computing distance function
 * - source_fields:             field variables that are to
 *                                be extended off of the zero
 *                                level set
 * - dX:                        array containing the grid spacing
 *                                in each coordinate direction,
 *                                ordered (dx, dy); it is reordered
 *                                internally to match meshgrid() order
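 * - mask:                      mask for domain of problem;
 *                                grid points outside of the domain
 *                                of the problem should be set to a
 *                                negative value.  If a scalar is
 *                                supplied instead of an array, it is
 *                                not used as a mask; it is forwarded
 *                                as the max_value argument of
 *                                computeExtensionFields2d_WithMaxVal()
 *                                (default = [])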
 * - spatial_derivative_order:  order of discretization for 
 *                                spatial derivatives
 *                                (default = 1)
 *
 * Return values:
 * - distance_function:         distance function
 * - extension_fields:          extension fields
 *
 * NOTES:
 * - All data arrays are assumed to be in the order generated by the
 *   MATLAB meshgrid() function.  That is, data corresponding to the
 *   point (x_i,y_j) is stored at index (j,i).
 *
 *===========================================================================*/

#include "mex.h"
#include "lsm_fast_marching_method.h" 

/*
 * Positions of the MATLAB arguments in the MEX gateway's argument arrays.
 * Inputs are read from prhs[], outputs are written to plhs[].
 */
enum {
  IN_PHI = 0,
  IN_SOURCE_FIELDS = 1,
  IN_DX = 2,
  IN_MASK = 3,
  IN_SPATIAL_DERIVATIVE_ORDER = 4
};

enum {
  OUT_DISTANCE_FUNCTION = 0,
  OUT_EXTENSION_FIELDS = 1
};

/* Input arguments (valid only inside mexFunction, where prhs is in scope) */
#define PHI                        (prhs[IN_PHI])
#define SOURCE_FIELDS              (prhs[IN_SOURCE_FIELDS])
#define DX                         (prhs[IN_DX])
#define MASK                       (prhs[IN_MASK])
#define SPATIAL_DERIVATIVE_ORDER   (prhs[IN_SPATIAL_DERIVATIVE_ORDER])

/* Output arguments (valid only inside mexFunction, where plhs is in scope) */
#define DISTANCE_FUNCTION          (plhs[OUT_DISTANCE_FUNCTION])
#define EXTENSION_FIELDS           (plhs[OUT_EXTENSION_FIELDS])


/*
 * Return the data pointer of 'array' after checking that it exists, is a
 * real double-precision array, and has exactly 'num_elements' entries.
 * Aborts the MEX function with 'err_msg' if any check fails.
 */
static double *getRealDoubleData(const mxArray *array, size_t num_elements,
                                 const char *err_msg)
{
  if (array == NULL || !mxIsDouble(array) || mxIsComplex(array)
      || mxGetNumberOfElements(array) != num_elements) {
    mexErrMsgTxt(err_msg);
  }
  return mxGetPr(array);
}

/*
 * MEX gateway for computeExtensionFields2d().
 *
 * Validates the MATLAB arguments and converts them to the types expected
 * by computeExtensionFields2d_WithMaxVal(). It then runs the fast marching
 * computation. The distance function is returned in plhs[0]. When a
 * second output is requested, plhs[1] receives a cell array with one
 * extension field per source field.
 *
 * The scratch pointer tables are allocated with mxMalloc(). MATLAB
 * reclaims them, along with any mxArrays not yet returned, if
 * mexErrMsgTxt() aborts the function part-way through.
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
  /* field data */
  double *phi;
  double *mask;
  double **source_fields;
  double *distance_function;
  double **extension_fields;
  mxArray *extension_fields_cell;
  int num_ext_fields;
  double max_value;

  /* grid data */
  const mwSize *phi_dims;
  size_t num_grid_points;
  int grid_dims[2];
  double *dX;
  double dX_matlab_order[2];
  mwSize cell_dims;

  /* numerical parameters */
  int spatial_derivative_order;

  /* auxiliary variables */
  int i;
  int error_code;
  mxArray *tmp_mxArray;

  /* Check for proper number of arguments before touching prhs[] */
  if (nrhs < 3) {
    mexErrMsgTxt(
      "Insufficient number of input arguments (3 required; 2 optional)");
  } else if (nrhs > 5) {
    mexErrMsgTxt("Too many input arguments (3 required; 2 optional)");
  } else if (nlhs > 2) {
    mexErrMsgTxt("Too many output arguments.");
  }

  /* Get phi and its grid dimensions */
  if (!mxIsDouble(PHI) || mxIsComplex(PHI)
      || mxGetNumberOfDimensions(PHI) != 2) {
    mexErrMsgTxt("phi must be a real, double-precision 2d array");
  }
  phi = mxGetPr(PHI);
  phi_dims = mxGetDimensions(PHI);   /* mwSize, NOT int, on 64-bit MATLAB */
  num_grid_points = mxGetNumberOfElements(PHI);

  /* The FMM library takes int dimensions; reject grids that do not fit. */
  grid_dims[0] = (int) phi_dims[0];
  grid_dims[1] = (int) phi_dims[1];
  if ((mwSize) grid_dims[0] != phi_dims[0]
      || (mwSize) grid_dims[1] != phi_dims[1]) {
    mexErrMsgTxt("phi is too large for computeExtensionFields2d");
  }

  /* Get grid spacing; both entries are read below */
  if (!mxIsDouble(DX) || mxGetNumberOfElements(DX) < 2) {
    mexErrMsgTxt("dX must be a double array with at least 2 elements");
  }
  dX = mxGetPr(DX);

  /* Get mask */
  mask = NULL;     /* NULL mask ==> all points are in interior of domain */
  max_value = -1;
  if (nrhs >= 4 && !mxIsEmpty(MASK)) {
    if (mxGetNumberOfElements(MASK) == 1) {
      /* A scalar is not a mask: it is forwarded as max_value. */
      max_value = mxGetScalar(MASK);
    } else {
      mask = getRealDoubleData(MASK, num_grid_points,
        "mask must be a real double array the same size as phi");
    }
  }

  /* Get spatial derivative order (empty ==> default) */
  spatial_derivative_order = 1;  /* KTC - change this to 2 after */
                                 /* implementing second-order algorithm */
  if (nrhs >= 5 && !mxIsEmpty(SPATIAL_DERIVATIVE_ORDER)) {
    spatial_derivative_order = (int) mxGetScalar(SPATIAL_DERIVATIVE_ORDER);
  }

  /* Assign pointers for source field data */
  if (!mxIsCell(SOURCE_FIELDS)) {
    mexErrMsgTxt("source_fields must be a cell array");
  }
  num_ext_fields = (int) mxGetNumberOfElements(SOURCE_FIELDS);
  /* +1 keeps the request non-zero-sized when there are no fields */
  source_fields = mxMalloc(((size_t) num_ext_fields + 1) * sizeof(double *));
  for (i = 0; i < num_ext_fields; i++) {
    /* mxGetCell() returns NULL for unset cells; the helper rejects that */
    source_fields[i] = getRealDoubleData(
      mxGetCell(SOURCE_FIELDS, (mwIndex) i), num_grid_points,
      "each element of source_fields must be a real double array "
      "the same size as phi");
  }

  /* Create distance function and extension field data */
  DISTANCE_FUNCTION = mxCreateDoubleMatrix(phi_dims[0], phi_dims[1], mxREAL);
  distance_function = mxGetPr(DISTANCE_FUNCTION);

  /* Build the extension-field cell locally: plhs[1] may not exist if the
   * caller asked for fewer than two outputs. */
  cell_dims = (mwSize) num_ext_fields;
  extension_fields_cell = mxCreateCellArray(1, &cell_dims);
  extension_fields =
    mxMalloc(((size_t) num_ext_fields + 1) * sizeof(double *));
  for (i = 0; i < num_ext_fields; i++) {
    tmp_mxArray = mxCreateDoubleMatrix(phi_dims[0], phi_dims[1], mxREAL);
    mxSetCell(extension_fields_cell, (mwIndex) i, tmp_mxArray);
    extension_fields[i] = mxGetPr(tmp_mxArray);
  }

  /* Change order of dX to match MATLAB meshgrid() order for grids. */
  dX_matlab_order[0] = dX[1];
  dX_matlab_order[1] = dX[0];

  /* Carry out FMM calculation */
  error_code = computeExtensionFields2d_WithMaxVal(
                 distance_function,
                 extension_fields,
                 phi,
                 mask,
                 source_fields,
                 num_ext_fields,
                 spatial_derivative_order,
                 grid_dims,
                 dX_matlab_order,
                 max_value);

  if (error_code) {
    /* mxMalloc'd tables and unreturned mxArrays are reclaimed by MATLAB */
    mexErrMsgTxt("computeExtensionFields2d failed...");
  }

  /* Return extension fields only if the caller asked for them */
  if (nlhs > 1) {
    EXTENSION_FIELDS = extension_fields_cell;
  } else {
    mxDestroyArray(extension_fields_cell);
  }

  /* Clean up memory */
  mxFree(source_fields);
  mxFree(extension_fields);

  return;
}
